# create a child of VisionEncoderDecoderModel class and only modify the forward method

# %%

from  datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor, BeitFeatureExtractor, BeitForMaskedImageModeling
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor, BeitFeatureExtractor, BeitForMaskedImageModeling
from typing import Optional, Tuple, Union
import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.bart.modeling_bart import shift_tokens_right
from transformers.models.bart.modeling_bart import BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
class VisionEncoderDecoderRLModel(VisionEncoderDecoderModel):
    """Vision encoder-decoder whose training loss can include a reward-weighted term.

    Only ``forward`` is overridden. When both ``labels`` and ``rl_rewards`` are
    given, the returned loss is the token-level cross-entropy plus
    ``-mean(log p(label) * reward)`` taken over non-padding label positions.
    Without ``rl_rewards`` only the cross-entropy is returned, and without
    ``labels`` no loss is computed (plain inference / generation).
    """

    # Label id excluded from the reward-weighted term. Presumably the GPT-2
    # EOS id reused as padding -- TODO confirm against the tokenizer in use.
    _RL_PAD_TOKEN_ID = 50256

    def _rl_loss(self, logits, labels, rewards):
        """Return ``-mean(log p(label) * reward)`` over non-padding label positions."""
        valid_tokens = labels != self._RL_PAD_TOKEN_ID
        if not valid_tokens.any():
            # mean() over an empty selection is NaN and would poison the total loss.
            return logits.new_zeros(())

        # Boolean-mask first so only real tokens are gathered: (n_valid, vocab).
        log_probs = torch.log_softmax(logits, dim=-1)[valid_tokens]
        target_ids = labels[valid_tokens].unsqueeze(-1)
        token_log_probs = log_probs.gather(-1, target_ids).squeeze(-1)

        # Rewards may be longer than the label sequence; keep the aligned prefix.
        token_rewards = rewards[:, : labels.shape[1]][valid_tokens]
        return -(token_log_probs * token_rewards).mean()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        rl_rewards = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode the image, decode, and optionally compute the training loss.

        Args:
            pixel_values: Image input for the encoder; required unless
                ``encoder_outputs`` is supplied.
            decoder_input_ids: Decoder inputs. When omitted (together with
                ``decoder_inputs_embeds``) and ``labels`` is given, they are
                derived from ``labels`` with ``shift_tokens_right``.
            encoder_outputs: Pre-computed encoder outputs; a plain tuple is
                wrapped in ``BaseModelOutput``.
            labels: Target token ids for the loss.
            rl_rewards: Per-token rewards, sliced to ``labels.shape[1]``
                columns and indexed with the label mask (so presumably shaped
                ``(batch, >= label_len)`` -- TODO confirm against the dataset).
            **kwargs: Extra arguments; those prefixed with ``decoder_`` go to
                the decoder (prefix stripped), the rest go to the encoder.

        Returns:
            A ``Seq2SeqLMOutput`` when ``return_dict`` is true, otherwise a
            tuple ``(loss?,) + decoder_outputs + encoder_outputs``.

        Raises:
            ValueError: If neither ``pixel_values`` nor ``encoder_outputs`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # optionally project encoder_hidden_states
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # No attention mask is applied over the encoder states.
        encoder_attention_mask = None

        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

            # The reward-weighted term is optional: a call without rewards
            # (evaluation, generation) must still succeed.
            if rl_rewards is not None:
                loss = loss + self._rl_loss(logits, labels, rl_rewards)

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
    


# %%

# Load the RL-enabled model from a checkpoint.
# NOTE(review): the checkpoint identifier is an empty string here -- supply a
# real path or hub id before running.
model = VisionEncoderDecoderRLModel.from_pretrained("")

# %%

from PIL import Image
from data_loader import get_base_datasets

# Project-local helper returning the three dataset splits (schema not visible
# here; later cells read the 'img_path', 'code' and 'id' columns).
train_dataset, val_dataset, test_dataset = get_base_datasets()



# %%


# Names of the pretrained image encoder and text decoder whose
# feature extractor / tokenizer are loaded below.
image_encoder_model = "microsoft/dit-base"
text_decode_model = "gpt2"
 
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
tokenizer = AutoTokenizer.from_pretrained(text_decode_model)



# %%

# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# update the model config
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Directory the trained model is saved to at the end of the script.
output_dir = "dit_rl_model"



# %%
import datasets

# Bundle the three splits into one DatasetDict; `ds` is an alias used below.
dataset = datasets.DatasetDict()
dataset["train"] = train_dataset
dataset["validation"] = val_dataset
dataset["test"] = test_dataset
ds = dataset

# %%
import pickle
# Critic outputs keyed by example id; preprocess_fn reads each entry's
# "logits" and "prediction" fields.
# NOTE(review): pickle.load can execute arbitrary code -- only load this file
# from a trusted source.
with open("critic_model_gpt2_token_pred_ground_predictions.pkl", "rb") as f:
    critic_dict_loaded = pickle.load(f)
#%%

def tokenization_fn(captions, max_target_length):
    """Tokenize captions into label id sequences of a fixed length.

    Every caption is padded to, and truncated at, ``max_target_length``
    tokens using the module-level ``tokenizer``.

    Returns:
        The ``input_ids`` produced by the tokenizer.
    """
    encoded = tokenizer(
        captions,
        padding="max_length",
        truncation=True,
        max_length=max_target_length,
    )
    return encoded.input_ids


def feature_extraction_fn(image_paths, check_image=True):
    """Load images and convert them to encoder pixel values.

    Each image is resized to 224x224, converted to RGB and passed through the
    module-level ``feature_extractor``.

    Args:
        image_paths: Iterable of image file paths.
        check_image: When true, files that ``Image.open`` fails on are skipped
            (and reported with a warning) instead of raising.

    Returns:
        The ``pixel_values`` from the feature extractor, requested as NumPy
        (``return_tensors="np"``). If files were skipped, the result has fewer
        rows than ``image_paths`` and no longer lines up one-to-one with it.
    """
    import warnings

    if check_image:
        images = []
        skipped = []
        for image_file in image_paths:
            try:
                images.append(Image.open(image_file))
            except Exception as exc:
                # Deliberately broad: this path is best-effort, but the
                # failure is recorded so it can be reported below.
                skipped.append((image_file, exc))
        if skipped:
            warnings.warn(
                f"feature_extraction_fn skipped {len(skipped)} unreadable image(s); "
                f"output is no longer aligned with the input paths: "
                f"{[str(path) for path, _ in skipped]}"
            )
    else:
        images = [Image.open(image_file) for image_file in image_paths]

    # Resize before converting to RGB (the order affects the resampling used).
    images = [img.resize((224, 224)).convert("RGB") for img in images]
    encoder_inputs = feature_extractor(images=images, return_tensors="np")

    return encoder_inputs.pixel_values

# %%

def get_reward_from_error_type(error_type):
    """Map a critic error-type code to its scalar reward.

    Codes: 0 = compile error, 1 = runtime error, 2 = failed unit tests,
    3 = passed all unit tests.

    Raises:
        NotImplementedError: If ``error_type`` matches none of the codes.
    """
    # Checked in order with ``==`` so any value that compares equal to a code
    # (not only hashable ints) is accepted.
    reward_by_code = (
        (0, -1),    # Compile error
        (1, -0.6),  # Runtime error
        (2, -0.3),  # Failed unit tests
        (3, 1),     # Passed all unit tests
    )
    for code, reward in reward_by_code:
        if error_type == code:
            return reward
    raise NotImplementedError()

def preprocess_fn(examples, max_target_length, check_image=True, missing_reward_length=1000):
    """Run tokenization + image feature extraction and attach per-token rewards.

    Args:
        examples: Batch with 'img_path', 'code' and 'id' columns.
        max_target_length: Length the tokenized 'code' is padded/truncated to.
        check_image: Forwarded to ``feature_extraction_fn``.
        missing_reward_length: Number of placeholder rewards (all 1.0) emitted
            for an example whose id has no entry in ``critic_dict_loaded``.

    Returns:
        Dict with 'labels', 'pixel_values' and 'rl_rewards' (a float32 tensor).
        All reward rows in a batch must have the same length for the tensor
        to be built.
    """
    model_inputs = {
        'labels': tokenization_fn(examples['code'], max_target_length),
        'pixel_values': feature_extraction_fn(examples['img_path'], check_image=check_image),
    }

    rl_rewards = []
    for example_id in examples["id"]:
        if example_id not in critic_dict_loaded:
            # No critic output for this example: use a neutral reward of 1.
            rl_rewards.append([1.0] * missing_reward_length)
            continue

        critic_entry = critic_dict_loaded[example_id]
        rl_prediction = critic_entry["prediction"]
        # Keep, for every position, the critic logit of the predicted class.
        rl_selected_logits = critic_entry["logits"][:, rl_prediction]
        # The reward for the predicted error type is the same for all positions.
        reward_scale = get_reward_from_error_type(rl_prediction)
        rl_rewards.append([logit * reward_scale for logit in rl_selected_logits])

    # Force one float dtype so placeholder rows and critic rows agree.
    model_inputs['rl_rewards'] = torch.tensor(rl_rewards, dtype=torch.float32)

    return model_inputs

# %%
# Maximum label length in tokens, passed to preprocess_fn as max_target_length.
MAX_LENGTH = 900


# Tokenize code, extract pixel values and attach rewards for every split;
# the original columns are dropped so only the model inputs remain.
processed_dataset = ds.map(
    function=preprocess_fn,
    batched=True,
    fn_kwargs={"max_target_length": MAX_LENGTH},
    remove_columns=ds['train'].column_names
)
# %%

processed_dataset.set_format("torch")

# train model on processed_dataset, use pytorch loop

# freeze all base layers
# for param in model.base_model.parameters():
#     param.requires_grad = False


# NOTE(review): this DataLoader is not referenced by the Seq2SeqTrainer below,
# which is handed the datasets directly -- presumably left over from a manual
# training loop; confirm before removing.
dataloader = torch.utils.data.DataLoader(processed_dataset["train"], batch_size=1, shuffle=True)


# %%

from transformers import default_data_collator
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments


# NOTE(review): output_dir is an empty string here, while the final
# save_model call uses the module-level `output_dir` variable -- confirm
# where checkpoints are meant to be written.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    output_dir="",
    report_to="wandb",
    num_train_epochs=1
    
)


# The feature extractor is passed in the `tokenizer` slot -- presumably so it
# is saved alongside the model; confirm.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=feature_extractor,
    args=training_args,
    # compute_metrics=compute_metrics,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['validation'],
    data_collator=default_data_collator,
)

trainer.train()

# save model
trainer.save_model(output_dir)

# %%